# runners/bohb_runner.py
from __future__ import annotations
import time
import numpy as np

import ConfigSpace as CS
import hpbandster.core.nameserver as hpns
from hpbandster.optimizers import BOHB
from hpbandster.core.worker import Worker

from objective import Objective
from loggers import ExperimentLogger

class _SingleFidelityWorker(Worker):
    """
    HpBandSter worker that evaluates each configuration at a single fidelity.

    The experiment logger is kept on the instance as ``exp_logger`` (it is
    not forwarded to the parent ``Worker``) and receives one record per
    evaluation.
    """

    def __init__(self, *,
                 cs: CS.ConfigurationSpace,
                 obj: Objective,
                 exp_logger: ExperimentLogger,
                 method_name: str,
                 bench: str,
                 seed: int,
                 **kwargs):
        """Store the run settings; remaining keyword arguments go to ``Worker``."""
        # Anything not named above (e.g. nameserver, nameserver_port, run_id)
        # is handed to the parent class unchanged.
        super().__init__(**kwargs)

        # Fields copied verbatim into every logged record.
        self.seed = seed
        self.bench = bench
        self.method = method_name

        # Collaborators used while evaluating.
        self.cs = cs
        self.obj = obj
        self.exp_logger = exp_logger

        # Running state: evaluations done so far and lowest loss seen.
        self.n_eval = 0
        self.best = float("inf")

    def compute(self, config, budget, **kwargs):
        """
        Evaluate ``config`` once, log the outcome and return its loss.

        ``budget`` and any extra keyword arguments are accepted but not used.
        Returns a dict with the loss as a plain ``float`` and an empty
        ``info`` dict.
        """
        self.n_eval += 1

        started = time.perf_counter()
        loss, sim_t = self.obj.evaluate(config)
        wall_time = time.perf_counter() - started

        # Keep the lowest loss observed so far.
        if loss < self.best:
            self.best = loss

        # Scores are logged as ``1 - loss``; sim_time is the second value
        # returned by ``obj.evaluate``.
        record = {
            "seed": self.seed,
            "method": self.method,
            "bench": self.bench,
            "n_eval": self.n_eval,
            "sim_time": sim_t,
            "elapsed_time": wall_time,
            "best_score": 1 - self.best,
            "curr_score": 1 - loss,
            "config": config,
        }
        self.exp_logger.log(record)

        return {"loss": float(loss), "info": {}}

    def get_configspace(self):
        """Return the configuration space supplied at construction time."""
        return self.cs


def run_bohb(*,
             seed: int,
             bench: str,
             cs: CS.ConfigurationSpace,
             obj: Objective,
             budget_n: int,
             logger: ExperimentLogger,
             method_name: str = "BOHB-HPBandSter",
             min_budget: float = 1.0,
             max_budget: float = 1.0,
             eta: float = 999999.0):
    """
    Run single-fidelity BOHB (min_budget=max_budget=1, eta very large → each
    round ≈ one evaluation).

    A local HpBandSter nameserver and one background worker are started,
    BOHB is run for ``budget_n`` iterations, and everything is shut down
    again, including when an error occurs part-way through.

    Args:
        seed: Seeds NumPy's global RNG; also logged and used in the run id.
        bench: Benchmark name; logged and used in the run id.
        cs: Configuration space handed to both the worker and BOHB.
        obj: Objective whose ``evaluate(config)`` the worker calls.
        budget_n: Passed to ``BOHB.run`` as ``n_iterations``.
        logger: Experiment logger that receives one record per evaluation.
        method_name: Label logged with every record and used in the run id.
        min_budget: Forwarded to BOHB.
        max_budget: Forwarded to BOHB.
        eta: Forwarded to BOHB.

    Returns:
        None. Results are recorded through ``logger``.
    """
    import logging
    import numpy as np
    import Pyro4

    np.random.seed(seed)

    # Key: switch serializer to avoid serpent failing to handle numpy.bool_ etc.
    # NOTE(review): this mutates process-wide Pyro4 configuration, and pickle
    # must never be accepted from untrusted peers. The nameserver below is
    # bound to 127.0.0.1; confirm the worker/master daemons are local-only too.
    Pyro4.config.SERIALIZER = 'pickle'
    Pyro4.config.SERIALIZERS_ACCEPTED = {'pickle', 'serpent', 'json', 'marshal'}

    run_id = f"{bench}_{method_name}_seed{seed}"
    ns = hpns.NameServer(run_id=run_id, host='127.0.0.1', port=0)
    ns_host, ns_port = ns.start()   # Returns host/port

    # From here on the nameserver is running, so every later step is guarded:
    # a failure while building the worker or the optimizer must not leak it.
    try:
        worker = _SingleFidelityWorker(
            cs=cs, obj=obj, exp_logger=logger,
            method_name=method_name, bench=bench, seed=seed,
            nameserver=ns_host, nameserver_port=ns_port, run_id=run_id
        )
        worker.run(background=True)

        bohb = BOHB(
            configspace=cs,
            min_budget=min_budget,
            max_budget=max_budget,
            eta=eta,
            nameserver=ns_host,
            nameserver_port=ns_port,
            run_id=run_id,
        )
        try:
            bohb.run(n_iterations=budget_n, min_n_workers=1)
        finally:
            bohb.shutdown(shutdown_workers=True)
    finally:
        # Runs even if bohb.shutdown() raised. Shutdown stays best-effort so
        # it cannot mask the original error, but failures are now reported.
        try:
            ns.shutdown()
        except Exception:
            logging.getLogger(__name__).warning(
                "Failed to shut down nameserver for run %s",
                run_id, exc_info=True,
            )
